{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../../docs/images/DSPy8.png\" alt=\"DSPy7 Image\" height=\"150\"/>\n",
    "\n",
    "## **DSPy Assertions**: Asserting Computational Constraints on Foundation Models\n",
    "\n",
    "### **QuizGen**: Generating multiple choice quiz questions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[<img align=\"center\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" />](https://colab.research.google.com/github/stanfordnlp/dspy/blob/main/examples/quiz/quiz_assertions.ipynb)\n",
    "\n",
    "\n",
    "This notebook highlights an example of [**DSPy Assertions**](https://dspy-docs.vercel.app/docs/building-blocks/assertions), allowing for declaration of computational constraints within DSPy programs. \n",
    "\n",
    "\n",
    "This notebook builds upon the foundational concepts of the **DSPy** framework. Prerequisites of following this notebook is having gone through the [DSPy tutorial](../../intro.ipynb), the [**DSPy Assertions documentation**](https://dspy-docs.vercel.app/docs/building-blocks/assertions) and the introductory DSPy Assertions [tutorial on LongFormQA](../longformqa/longformqa_assertions.ipynb).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://huggingface.co/arnavs11/DSPy_QuizGen_Cache\n",
    "%cd DSPy_QuizGen_Cache/\n",
    "!git checkout master\n",
    "%cd ..\n",
    "import os\n",
    "repo_clone_path = '/content/DSPy_QuizGen_Cache'\n",
    "\n",
    "# Check if '/content' is writable\n",
    "if not os.access('/content', os.W_OK):\n",
    "    # If '/content' is not writable, choose an alternative directory\n",
    "    # Example: using a directory relative to the current working directory\n",
    "    repo_clone_path = os.path.join(os.getcwd(), 'DSPy_QuizGen_Cache')\n",
    "\n",
    "# Set up the cache for this notebook\n",
    "os.environ[\"DSP_NOTEBOOK_CACHEDIR\"] = repo_clone_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import os\n",
    "import regex as re\n",
    "import json\n",
    "\n",
    "try: # When on google Colab, let's clone the notebook so we download the cache.\n",
    "    import google.colab\n",
    "    repo_path = 'dspy'\n",
    "    \n",
    "    !git -C $repo_path pull origin || git clone https://github.com/stanfordnlp/dspy $repo_path\n",
    "except:\n",
    "    repo_path = '.'\n",
    "\n",
    "if repo_path not in sys.path:\n",
    "    sys.path.append(repo_path)\n",
    "\n",
    "\n",
    "import pkg_resources # Install the package if it's not installed\n",
    "if not \"dspy-ai\" in {pkg.key for pkg in pkg_resources.working_set}:\n",
    "    !pip install -U pip\n",
    "    !pip install dspy-ai\n",
    "    !pip install openai~=0.28.1\n",
    "    !pip install -e $repo_path\n",
    "\n",
    "import dspy\n",
    "from dspy.predict import Retry\n",
    "from dspy.datasets import HotPotQA\n",
    "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n",
    "from dspy.evaluate.evaluate import Evaluate\n",
    "from dspy.primitives.assertions import assert_transform_module, backtrack_handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')\n",
    "dspy.settings.configure(rm=colbertv2_wiki17_abstracts)\n",
    "turbo = dspy.OpenAI(model='gpt-3.5-turbo-0613', max_tokens=500)\n",
    "dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = HotPotQA(train_seed=1, train_size=300, eval_seed=2023, dev_size=300, test_size=0, keep_details=True)\n",
    "trainset = [x.with_inputs('question', 'answer') for x in dataset.train]\n",
    "devset = [x.with_inputs('question', 'answer') for x in dataset.dev]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3] QuizGen\n",
    "\n",
    "Let's introduce a new task: QuizGen. \n",
    "\n",
    "QuizGen takes HotPotQA data points and turns them into multiple choice quiz questions with the corresponding options. Each set of options for the question is produced in a JSON key-value pair format. For this case, we specify the generation of 4 choices."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this program, we aim to generate quiz choices that adhere to the following guidelines:\n",
    "1. The generated choices are in a JSON format.\n",
    "2. The generated choices include the correct answer.\n",
    "3. The generated choices include plausible distractor options besides the correct answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GenerateAnswerChoices(dspy.Signature):\n",
    "    \"\"\"Generate answer choices in JSON format that include the correct answer and plausible distractors for the specified question.\"\"\"\n",
    "    question = dspy.InputField()\n",
    "    correct_answer = dspy.InputField()\n",
    "    number_of_choices = dspy.InputField()\n",
    "    answer_choices = dspy.OutputField(desc='JSON key-value pairs')\n",
    "\n",
    "class QuizAnswerGenerator(dspy.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.generate_choices = dspy.ChainOfThought(GenerateAnswerChoices)\n",
    "\n",
    "    def forward(self, question, answer):\n",
    "        choices = self.generate_choices(question=question, correct_answer=answer, number_of_choices=number_of_choices).answer_choices\n",
    "        return dspy.Prediction(choices = choices)\n",
    "\n",
    "number_of_choices = '4'\n",
    "quiz_generator = QuizAnswerGenerator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4] Evaluation - Intrinsic and Extrinsic\n",
    "\n",
    "#### Intrinsic Metrics: passing internal computational constraints is the goal \n",
    "\n",
    "**Valid Formatting** - The outputted answer choices should be in JSON format which is verified after parsing the key-value pairs.\n",
    "\n",
    "**Correct Answer Inclusion** - This is a general check to ensure the generated quiz choices actually include the correct answer to the question.\n",
    "\n",
    "**Plausible Distractors** - This validation is to check that the generated choices include distractor answer options that are reasonable options as answers to the question. We define and call another **DSPy** program: ``Predict`` on ``AssessQuizChoices``, relying on the same LM to answer the question: `\"Are the distractors in the answer choices plausible and not easily identifiable as incorrect?\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_checker(choice_string):\n",
    "    try:\n",
    "        choices = json.loads(choice_string)\n",
    "        if isinstance(choices, dict) and all(isinstance(key, str) and isinstance(value, str) for key, value in choices.items()):\n",
    "            return True\n",
    "    except json.JSONDecodeError:\n",
    "        return False\n",
    "\n",
    "    return False\n",
    "\n",
    "def is_correct_answer_included(correct_answer, generated_choices):\n",
    "    try:\n",
    "        choices_dict = json.loads(generated_choices)\n",
    "        return correct_answer in choices_dict.values()\n",
    "    except json.JSONDecodeError:\n",
    "        return False\n",
    "\n",
    "def is_plausibility_yes(assessment_answer):\n",
    "    \"\"\"Check if the first word of the assessment answer is 'yes'.\"\"\"\n",
    "    return assessment_answer.split()[0].lower() == 'yes'\n",
    "    \n",
    "class AssessQuizChoices(dspy.Signature):\n",
    "    \"\"\"Assess the quality of quiz answer choices along specified dimensions.\"\"\"\n",
    "    \n",
    "    question = dspy.InputField()\n",
    "    answer_choices = dspy.InputField()\n",
    "    assessment_question = dspy.InputField()\n",
    "    assessment_answer = dspy.OutputField(desc=\"Yes or No\")\n",
    "    \n",
    "def format_valid_metric(gold, pred, trace=None):\n",
    "    generated_choices = pred.choices\n",
    "    format_valid = format_checker(generated_choices)\n",
    "    score = format_valid\n",
    "    return score\n",
    "\n",
    "def is_correct_metric(gold, pred, trace=None):\n",
    "    correct_answer, generated_choices = gold.answer, pred.choices\n",
    "    correct_included = is_correct_answer_included(correct_answer, generated_choices)\n",
    "    score = correct_included\n",
    "    return score\n",
    "\n",
    "def plausibility_metric(gold, pred, trace=None):\n",
    "    question, generated_choices = gold.question, pred.choices\n",
    "    plausibility_question = \"Are the distractors in the answer choices plausible and not easily identifiable as incorrect?\"\n",
    "    plausibility_assessment = dspy.Predict(AssessQuizChoices)(question=question, answer_choices=generated_choices, assessment_question=plausibility_question)\n",
    "    plausibility_result = plausibility_assessment.assessment_answer.split()[0].lower() == 'yes'\n",
    "    score = plausibility_result\n",
    "    return score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Extrinsic Metrics: Assess the overall quality and effectiveness of generated output on downstream task\n",
    "\n",
    "The extrinsic metric is defined as the overall quality of the generated quiz choices and is evaluated over a composite metric, accounting for these constraints.\n",
    "\n",
    "The composite metric maintains the core intrinsic metrics required for producing a valid set of quiz choices in validating valid formatting and correct answere icnlusion, and the overall composite metric returns an averaged score over the 3 intrinsic metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def overall_metric(gold, pred, trace=None):\n",
    "    question, correct_answer, generated_choices = gold.question, gold.answer, pred.choices\n",
    "    format_valid = format_checker(generated_choices)\n",
    "    correct_included = is_correct_answer_included(correct_answer, generated_choices)\n",
    "    plausibility_question = \"Are the distractors in the answer choices plausible and not easily identifiable as incorrect?\"\n",
    "    plausibility_assessment = dspy.Predict(AssessQuizChoices)(question=question, answer_choices=generated_choices, assessment_question=plausibility_question)\n",
    "    plausibility_result = plausibility_assessment.assessment_answer.split()[0].lower() == 'yes'\n",
    "    score = (format_valid + correct_included + plausibility_result) / 3.0 if correct_included and format_valid else 0\n",
    "    return score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We hence define the evaluation as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = [format_valid_metric, is_correct_metric, plausibility_metric, overall_metric]\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(quiz_generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at an example quiz choice generation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = devset[67]\n",
    "quiz_choices = quiz_generator(question=example.question, answer = example.answer)\n",
    "print(f'Generated Quiz Choices: ', quiz_choices.choices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset[67:68], num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(quiz_generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the generated quiz choices do not maintain valid JSON formatting, which violates the valid formatting and correctness check, even though the choices are noted as plausible. We also see that the correct answer is also labeled by \"(Correct Answer)\", which is not the intention of producing good quiz question answer choices.  \n",
    "\n",
    "Let's take a look at how we can integrate DSPy Assertions and impose constraints to produce better answer choices."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5] Introducing Assertions: QuizAnswerGeneratorWithAssertions\n",
    "Let's include assertions that simply reiterate our computational constraints within DSPy Assertion semantics. \n",
    "\n",
    "In the first **Assertion**, we check for if the generated quiz choices are in JSON format and if not, assert: **\"The format of the answer choices should be in JSON format. Please revise accordingly.\"**\n",
    "\n",
    "We also check for if the set of quiz choices includes the correct answer and ensure this if violated with the feedback message: **\"The answer choices do not include the correct answer to the question. Please revise accordingly.\"**\n",
    "\n",
    "Lastly, we assess if the plausible distractor choices are indeed good distractor options and if not, assert: **\"The answer choices are not plausible distractors or are too easily identifiable as incorrect. Please revise to provide more challenging and plausible distractors.\"**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QuizAnswerGeneratorWithAssertions(dspy.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.generate_choices = dspy.ChainOfThought(GenerateAnswerChoices)\n",
    "\n",
    "    def forward(self, question, answer):\n",
    "        choice_string = self.generate_choices(question=question, correct_answer=answer, number_of_choices=number_of_choices).answer_choices\n",
    "        dspy.Suggest(format_checker(choice_string), \"The format of the answer choices should be in JSON format. Please revise accordingly.\", target_module=GenerateAnswerChoices)\n",
    "        dspy.Suggest(is_correct_answer_included(answer, choice_string), \"The answer choices do not include the correct answer to the question. Please revise accordingly.\", target_module=GenerateAnswerChoices)\n",
    "        plausibility_question = \"Are the distractors in the answer choices plausible and not easily identifiable as incorrect?\"\n",
    "        plausibility_assessment = dspy.Predict(AssessQuizChoices)(question=question, answer_choices=choice_string, assessment_question=plausibility_question)\n",
    "        dspy.Suggest(is_plausibility_yes(plausibility_assessment.assessment_answer), \"The answer choices are not plausible distractors or are too easily identifiable as incorrect. Please revise to provide more challenging and plausible distractors.\", target_module=GenerateAnswerChoices)\n",
    "        return dspy.Prediction(choices = choice_string)\n",
    "\n",
    "number_of_choices = '4'\n",
    "quiz_generator_with_assertions = assert_transform_module(QuizAnswerGeneratorWithAssertions().map_named_predictors(Retry), backtrack_handler) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's evaluate the `QuizAnswerGeneratorWithAssertions` now over the devset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = [format_valid_metric, is_correct_metric, plausibility_metric, overall_metric]\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(quiz_generator_with_assertions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's take a look at how our generated set of quiz choices has improved with the addition of assertions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = devset[67]\n",
    "quiz_choices = quiz_generator_with_assertions(question=example.question, answer = example.answer)\n",
    "print(f'Generated Quiz Choices: ', quiz_choices.choices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset[67:68], num_threads=1, display_progress=True, display_table=30)\n",
    "    evaluate(quiz_generator_with_assertions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the quiz choices follow all of our constraints!\n",
    "\n",
    "Not only are the answer choices all plausible, and have removed any indicator of what the correct answer could be, but the answer choices now maintain valid JSON formatting with 4 possible answer choices to the question, which includes the correct answer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6] Compilation With Assertions\n",
    "\n",
    "We can leverage **DSPy**'s`BootstrapFewShotWithRandomSearch` optimizer, to automatically generate few-shot demonstrations and conduct a random search over the candidates to output the best compiled program. We evaluate this over the `final_metric` composite metric. \n",
    "\n",
    "We can first evaluate this on `QuizAnswerGenerator` to see how compilation performs without the inclusion of assertions. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
    "compiled_quiz_generator = teleprompter.compile(student = quiz_generator, teacher = quiz_generator, trainset=trainset, valset=devset[:100])\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(compiled_quiz_generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we test the compilation on 2 settings with assertions:\n",
    "\n",
    "**Compilation with Assertions**: assertion-driven example bootstrapping and counterexample bootstrapping during compilation. Teacher has assertions while the student does not as the student learns from the teacher's assertion-driven bootstrapped examples. \n",
    "\n",
    "**Compilation + Inference with Assertions**: assertion-driven optimizations for both the teacher and student to offer enhanced assertion-driven outputs during both compilation and inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
    "compiled_with_assertions_quiz_generator = teleprompter.compile(student=quiz_generator, teacher = quiz_generator_with_assertions, trainset=trainset, valset=devset[:100])\n",
    "\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(compiled_with_assertions_quiz_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
    "compiled_quiz_generator_with_assertions = teleprompter.compile(student=quiz_generator_with_assertions, teacher = quiz_generator_with_assertions, trainset=trainset, valset=devset[:100])\n",
    "\n",
    "for metric in metrics:\n",
    "    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
    "    evaluate(compiled_quiz_generator_with_assertions)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dspy_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
